package org.zjvis.datascience.service.dataprovider;

import java.sql.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.exception.DataScienceException;
import org.zjvis.datascience.common.model.AggrConfig;
import org.zjvis.datascience.common.model.Table;
import com.google.common.collect.Lists;
import org.zjvis.datascience.common.util.SqlUtil;
import org.zjvis.datascience.common.util.db.JDBCUtil;

/**
 * @description GP数据provider 提供数据库连接池
 * @date 2021-11-02
 */
@Service("gpDataProvider")
public class GPDataProvider extends JdbcDataProvider {

    @Override
    public List<String> showDatabases(Long dsId) throws DataScienceException {
        List<String> ret = Lists.newArrayList();
        Connection conn = null;
        try {
            conn = getConn(dsId);
            Statement stat = conn.createStatement();
            stat.setQueryTimeout(SQL_QUERY_TIMEOUT_SECOND);
            ResultSet rs = stat.executeQuery("select datname from pg_database");
            while (rs.next()) {
                ret.add(rs.getString(1));
            }
        } catch (Exception e) {
            logger.error("show databases error, dsId=" + dsId, e);
            throw new DataScienceException("dsId=" + dsId + "\n" + e.getMessage());
        } finally {
            if (conn != null) {
                try {
                    conn.close();
                } catch (SQLException e) {
                }
            }
        }
        return ret;
    }

    @Override
    public List<String> showTables(Long dsId, String db) throws DataScienceException {
        List<String> ret = Lists.newArrayList();
        Connection conn = null;
        try {
            conn = getConn(dsId);
            Statement stat = conn.createStatement();
            stat.setQueryTimeout(SQL_QUERY_TIMEOUT_SECOND);
//          stat.execute("use " + db);
            ResultSet rs = stat.executeQuery("select table_name from information_schema.tables");
            while (rs.next()) {
                ret.add(rs.getString(1));
            }
        } catch (Exception e) {
            logger.error("show tables error, dsId=" + dsId, e);
            throw new DataScienceException("dsId=" + dsId + "\n" + e.getMessage());
        } finally {
            if (conn != null) {
                try {
                    conn.close();
                } catch (SQLException e) {
                }
            }
        }
        return ret;
    }


    public Boolean checkTableIfExist(Table table) {
        String sql = "SELECT COUNT(1) FROM " + table.getName();

        Connection conn = null;
        try {
            conn = getConn(table.getDsId());
            Statement stat = conn.createStatement();
            stat.executeQuery(sql);
            return true;
        } catch (Exception e) {
            return false;
        } finally {
            if (conn != null) {
                try {
                    conn.close();
                } catch (SQLException e) {
                }
            }
        }
    }

    public String getSql(Table table, AggrConfig config) throws DataScienceException {
        String sql = super.getSql(table, config);
        return sql.replaceAll("`", "\"");
    }

    public void executeSql(String sql) {
        Connection conn = null;
        try {
            conn = getConn(1L);
            Statement st = conn.createStatement();
            executeSql(st, sql);
        } catch (Exception e) {
            logger.error(e.getMessage());
        } finally {
            JDBCUtil.close(conn, null, null);
        }
    }

    public void executeSql(Statement st, String sql) throws SQLException {
        st.execute(sql);
    }

    public JSONArray executeQuery(String selectSql) {
        JSONArray sqlData = new JSONArray();
        Connection conn = null;
        try {
            conn = getConn(1L);
            Statement st = conn.createStatement();
            sqlData = executeQuery(st, selectSql);
        } catch (Exception e) {
            logger.error(e.getMessage());
        } finally {
            JDBCUtil.close(conn, null, null);
        }
        return sqlData;
    }

    public JSONArray executeQuery(Statement st, String selectSql) throws SQLException {
        ResultSet rs = st.executeQuery(selectSql);
        ResultSetMetaData meta = rs.getMetaData();

        JSONArray values = new JSONArray();
        while (rs.next()) {
            JSONObject column = new JSONObject();
            for (int i = 1; i < meta.getColumnCount() + 1; i++) {
                String name = meta.getColumnName(i);
                column.put(name, rs.getString(name));
            }
            values.add(column);
        }

        return values;
    }

    public Set<String> executeQueryAsOneSet(String selectSql) {
        Set<String> sqlData = new HashSet<>();
        Connection conn = null;
        try {
            conn = getConn(1L);
            Statement st = conn.createStatement();
            sqlData = executeQueryAsOneSet(st, selectSql);
        } catch (Exception e) {
            logger.error(e.getMessage());
        } finally {
            JDBCUtil.close(conn, null, null);
        }
        return sqlData;
    }

    public Set<String> executeQueryAsOneSet(Statement st, String selectSql) throws SQLException {
        Set<String> sqlData = new HashSet<>();
        ResultSet rs = st.executeQuery(selectSql);
        ResultSetMetaData meta = rs.getMetaData();

        while (rs.next()) {
            for (int i = 1; i < meta.getColumnCount() + 1; i++) {
                if (rs.getString(meta.getColumnName(i)) != null) {
                    String str = rs.getString(meta.getColumnName(i));
                    if (str != null && !str.trim().isEmpty()) {
                        sqlData.add(str);
                    }
                }
            }
        }
        return sqlData;
    }

    public List<String> getDataFromJSON(JSONArray values, String name, int num) {
        List<String> ret = new ArrayList<>();
        int count = 0;
        for (int i = 0; i < values.size(); i++) {
            String data = values.getJSONObject(i).getString(name);
            try {
                if (data != null) {
                    ret.add(data);
                    count += 1;
                    if (num > 0 && count > num - 1) {
                        return ret;
                    }
                }
            } catch (Exception e) {
                logger.debug(e.getMessage());
            }
        }
        return ret;
    }

}
